package spark.ml;

import com.google.common.collect.Lists;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

import java.util.List;

/**
 * Minimal text-classification example built on the Spark ML Pipeline API.
 *
 * <p>Fits a Tokenizer -&gt; HashingTF -&gt; LogisticRegression pipeline on a handful of
 * in-memory labeled documents, then prints the model's output for a few
 * unlabeled documents.
 *
 * <p>Author: LDL
 * <br>Created: 2015/6/28 15:51
 */
public class PipelineTest {

    /**
     * Runs the example against a {@code local} Spark master and prints one line per
     * test document to standard output.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("JavaSimpleTextClassificationPipeline");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // Everything after context creation runs inside try/finally so the context is
        // stopped even when fitting, transforming or collecting throws.
        try {
            SQLContext jsql = new SQLContext(jsc);

            // Labeled training documents: (id, text, label).
            List<LabeledDocument> localTraining = Lists.newArrayList(
                    new LabeledDocument(0L, "a b c d e spark", 1.0),
                    new LabeledDocument(1L, "b d", 0.0),
                    new LabeledDocument(2L, "spark f g h", 1.0),
                    new LabeledDocument(3L, "hadoop mapreduce", 0.0));
            DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);

            // Fit all pipeline stages on the training data.
            PipelineModel model = buildPipeline().fit(training);

            // Unlabeled test documents: (id, text).
            List<Document> localTest = Lists.newArrayList(
                    new Document(4L, "spark i j k"),
                    new Document(5L, "l m n"),
                    new Document(6L, "mapreduce spark"),
                    new Document(7L, "apache hadoop"));
            DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);

            // Apply the fitted pipeline and print the selected columns for each row.
            DataFrame predictions = model.transform(test);
            for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) {
                System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
                        + ", prediction=" + r.get(3));
            }
        } finally {
            jsc.stop();
        }
    }

    /**
     * Assembles the three-stage pipeline: a tokenizer reading the {@code text} column
     * into {@code words}, a hashing term-frequency stage with 1000 features writing
     * {@code features}, and a logistic regression (10 iterations max, regularization
     * parameter 0.01).
     *
     * @return the unfitted pipeline
     */
    private static Pipeline buildPipeline() {
        Tokenizer tokenizer = new Tokenizer()
                .setInputCol("text")
                .setOutputCol("words");
        HashingTF hashingTF = new HashingTF()
                .setNumFeatures(1000)
                // Chain off the tokenizer's output column rather than repeating the name.
                .setInputCol(tokenizer.getOutputCol())
                .setOutputCol("features");
        LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.01);
        return new Pipeline()
                .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});
    }
}
